from dateutil.parser import parse
import pandas as pd
from hcquant.factor import winsorize
import numpy as np

def get_industry(trade_dt, engine):
    """
    Wind low-level database version.

    Return, for one trading-date cross-section, each stock's industry name
    together with its (log-transformed, winsorized) market value.

    Args:
        trade_dt: The cross-section date. A string is parsed with
            ``dateutil.parser.parse``; any other value must provide
            ``strftime`` (e.g. a ``datetime``/``date``/``Timestamp``).
        engine: Database connection/engine handed to ``pandas.read_sql``.

    Returns:
        A ``DataFrame`` with the columns ``sid``, ``mkt_value`` and
        ``indname``. Only stocks present in both the market-value query
        and the industry query are kept (inner join on ``sid``).
    """
    # Normalise the date to the 'YYYYMMDD' text form used for both the
    # in-memory entry_dt comparison and the TRADE_DT filter in SQL.
    as_of = parse(trade_dt) if isinstance(trade_dt, str) else trade_dt
    date_key = as_of.strftime('%Y%m%d')

    industry_sql = """
    select a.S_INFO_WINDCODE sid, a.ENTRY_DT entry_dt, b.INDUSTRIESNAME indname
    FROM nwind.dbo.AShareSWIndustriesClass as a, nWind.dbo.ASHAREINDUSTRIESCODE as b
    where b.LEVELNUM = '2'
    and LEFT(a.SW_IND_CODE, 4) = LEFT(b.INDUSTRIESCODE, 4)
    ORDER by sid, ENTRY_DT
    """
    membership = pd.read_sql(industry_sql, engine)

    # Keep only classifications that had already started on the requested
    # date. The query orders rows by sid and ENTRY_DT, so the last row per
    # sid is the most recent classification.
    already_entered = membership.loc[membership['entry_dt'] <= date_key]
    latest_per_stock = already_entered.groupby(['sid']).last()
    industry = latest_per_stock['indname'].reset_index()

    mkt_value_sql = """
    SELECT S_INFO_WINDCODE sid, S_VAL_MV mkt_value from nWind.dbo.AShareEODDerivativeIndicator
    where TRADE_DT = '{}'
    order by S_INFO_WINDCODE
    """.format(date_key)
    market = pd.read_sql(mkt_value_sql, engine)

    # Wind's figures need to be multiplied by 10000 (per the original
    # author's note); the natural log of the scaled value is stored.
    scaled_value = market['mkt_value'] * 10000
    market['mkt_value'] = np.log(scaled_value)

    result = market.merge(industry, on='sid', how='inner')
    # winsorize comes from hcquant.factor; presumably it clips extreme
    # values -- confirm the exact rule against that module.
    result['mkt_value'] = winsorize(result['mkt_value'])
    return result